from abc import ABCMeta, abstractmethod
from dataclasses import asdict, dataclass
from typing import Optional, Type, TypeVar

import torch
import torch.nn as nn

from .attention_mask import AttentionMask


@dataclass
class AttentionConfig:
    """Parameters required for all Attentions.
    Can accept and store extra parameters.
    """

    name: str  # the registered name for this attention mechanism
    dropout: float  # dropout probability
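
# Illustrative construction (the attention name must match one registered with the
# library; "my_attention" below is just a placeholder):
#   config = AttentionConfig(name="my_attention", dropout=0.1)
# Configs for specific mechanisms may extend this dataclass with additional fields.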


Self = TypeVar("Self", bound="Attention")


# Define the common interface, every attention block needs to derive from it
class Attention(nn.Module, metaclass=ABCMeta):
    r"""The base Attention mechanism, which is typically a sub-part of the multi-head attention"""

    _causal_mask: Optional[AttentionMask] = None

    @abstractmethod
    def __init__(self, dropout: Optional[float] = None, *args, **kwargs):
        super().__init__()

        # Requires the inputs to be projected
        self.requires_input_projection = True

        # Whether the head dimension needs to be present (if not it can be folded into the batch dimension)
        self.requires_head_dimension = False

        # Whether the key padding mask and the attention mask must be passed in as two
        # separate arguments, instead of being merged into a single attention mask
        self.requires_separate_masks = False

        # Requires that K and Q have the same sequence length
        self.requires_same_k_q_dimensions = False

        # Whether this attention handles the single-head/multi-head mechanism itself,
        # in which case the multi-head attention (MHA) wrapper should skip that step
        self.requires_skip_multi_head = False

        # Whether this attention requires the context (sequence) length to be a perfect
        # square, which is typically the case when an internal 2D pooling step is involved
        self.requires_squared_context = False

        # Whether this attention mechanism supports attention masks
        self.supports_attention_mask = True

        # Whether this attention mechanism supports key padding masks
        self.supports_key_padding_mask = False
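
        # Example (hypothetical subclass): an attention that consumes the key padding
        # mask directly could flip these flags in its own __init__, e.g.
        #   self.requires_separate_masks = True
        #   self.supports_key_padding_mask = True
        # so that the multi-head wrapper passes the two masks in separately.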

    @classmethod
    def from_config(cls: Type[Self], config: AttentionConfig) -> Self:
        # Generate the class inputs from the config
        fields = asdict(config)

        # Skip all Nones so that default values are used
        fields = {k: v for k, v in fields.items() if v is not None}

        return cls(**fields)
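
    # Illustrative use of ``from_config`` (``MyAttention`` stands for any concrete
    # subclass; the ``name`` field is forwarded to the constructor along with the
    # other fields and is expected to be absorbed by ``**kwargs``):
    #   config = AttentionConfig(name="my_attention", dropout=0.1)
    #   attention = MyAttention.from_config(config)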

    @abstractmethod
    def forward(
        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, *args, **kwargs
    ) -> torch.Tensor:
        raise NotImplementedError

    @staticmethod
    def _maybe_pad_sequence(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """
        If the sequence is shorter than the mask, return a zero-padded copy of it.
        Otherwise return the tensor unchanged.
        """
        if x.shape[-2] != mask.shape[-1]:
            if x.shape[-2] > mask.shape[-1]:
                raise ValueError(
                    "Sequence is longer than the provided mask, cannot infer what to do with it. "
                    "Please update your attention mask"
                )

            # Pad the sequence dimension (second to last) on the right; the batch and
            # embedding dimensions are left untouched
            pad_size = (0, 0, 0, mask.shape[-1] - x.shape[-2], 0, 0)
            return torch.nn.functional.pad(x, pad_size, mode="constant", value=0.0)

        return x
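
    # Illustrative shapes for ``_maybe_pad_sequence`` (values assumed for the example):
    #   x = torch.randn(2, 10, 64)        # (batch, seq, embedding)
    #   mask = torch.ones(2, 16).bool()   # mask built for a sequence of length 16
    #   padded = Attention._maybe_pad_sequence(x, mask)
    #   assert padded.shape == (2, 16, 64)  # six zero rows appended along the sequence dim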
